{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0187c2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import pyceres\n",
    "import pycolmap\n",
    "import numpy as np\n",
    "from hloc.utils import viz_3d\n",
    "from copy import deepcopy\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22e87330",
   "metadata": {},
   "source": [
    "## Setup the toy example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba54109b-79a0-4620-9679-ecbcb504ab33",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_rotation(max_=np.pi * 2):\n",
    "    \"\"\"Sample a random rotation about a uniformly random axis.\n",
    "\n",
    "    The 3-vector handed to ``pycolmap.Rotation3d`` (its axis-angle\n",
    "    constructor) has a norm drawn uniformly from ``[0, max_)``, i.e. the\n",
    "    rotation angle in radians.\n",
    "    \"\"\"\n",
    "    # A Gaussian 3-vector has a uniformly distributed direction.\n",
    "    direction = np.random.randn(3)\n",
    "    # Factor that rescales the vector so its norm equals the sampled angle.\n",
    "    scale = np.random.rand() * max_ / np.linalg.norm(direction)\n",
    "    return pycolmap.Rotation3d(direction * scale)\n",
    "\n",
    "\n",
    "def error(i_from_w, j_from_w):\n",
    "    \"\"\"Measure the discrepancy between two poses.\n",
    "\n",
    "    Returns a tuple ``(angle_deg, offset)``: the rotation angle (converted\n",
    "    from radians to degrees) and the translation norm of the relative\n",
    "    transform ``i_from_w * j_from_w.inverse()``.\n",
    "    \"\"\"\n",
    "    delta = i_from_w * j_from_w.inverse()\n",
    "    angle_deg = np.rad2deg(delta.rotation.angle())\n",
    "    offset = np.linalg.norm(delta.translation)\n",
    "    return angle_deg, offset\n",
    "\n",
    "\n",
    "# Seed NumPy's global RNG so that Restart & Run All reproduces the same toy problem.\n",
    "np.random.seed(0)\n",
    "\n",
    "# Ground-truth poses: random rotations, translation components uniform in [0, 10).\n",
    "num = 20\n",
    "i_from_w = [\n",
    "    pycolmap.Rigid3d(sample_rotation(), np.random.rand(3) * 10) for _ in range(num)\n",
    "]\n",
    "# Relative pose between each pose and its successor; the last one wraps around to pose 0.\n",
    "i_from_j = [i_from_w[i] * i_from_w[(i + 1) % num].inverse() for i in range(num)]\n",
    "\n",
    "# Initial estimates: ground truth composed with a random perturbation\n",
    "# (rotation vector norm below pi / 5, translation components in [0, 1)).\n",
    "i_from_w_init = [\n",
    "    i_from_w[i] * pycolmap.Rigid3d(sample_rotation(np.pi / 5), np.random.rand(3))\n",
    "    for i in range(num)\n",
    "]\n",
    "# Pose 0 starts exactly at its ground-truth value.\n",
    "i_from_w_init[0] = i_from_w[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1d71bcf",
   "metadata": {},
   "source": [
    "## PGO with relative poses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81913517",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on a copy so the shared initial estimates stay untouched.\n",
    "i_from_w_opt = deepcopy(i_from_w_init)\n",
    "err_opt = np.array([error(gt, est) for gt, est in zip(i_from_w, i_from_w_opt)])\n",
    "print(\"initial:\", np.mean(err_opt, 0))\n",
    "\n",
    "prob = pyceres.Problem()\n",
    "loss = pyceres.TrivialLoss()\n",
    "costs = []\n",
    "for i, pose_i in enumerate(i_from_w_opt):\n",
    "    # Constrain pose i against its successor; the last pose wraps around to pose 0.\n",
    "    j = (i + 1) % num\n",
    "    pose_j = i_from_w_opt[j]\n",
    "    # np.eye(6) is the cost's second argument (presumably a 6x6 covariance - confirm).\n",
    "    cost = pycolmap.cost_functions.MetricRelativePoseErrorCost(i_from_j[i], np.eye(6))\n",
    "    costs.append(cost)\n",
    "    prob.add_residual_block(\n",
    "        cost,\n",
    "        loss,\n",
    "        [\n",
    "            pose_i.rotation.quat,\n",
    "            pose_i.translation,\n",
    "            pose_j.rotation.quat,\n",
    "            pose_j.translation,\n",
    "        ],\n",
    "    )\n",
    "    prob.set_manifold(pose_i.rotation.quat, pyceres.EigenQuaternionManifold())\n",
    "\n",
    "# Hold pose 0 constant during the solve (both its quaternion and translation blocks).\n",
    "anchor = i_from_w_opt[0]\n",
    "prob.set_parameter_block_constant(anchor.rotation.quat)\n",
    "prob.set_parameter_block_constant(anchor.translation)\n",
    "\n",
    "options = pyceres.SolverOptions()\n",
    "options.linear_solver_type = pyceres.LinearSolverType.SPARSE_NORMAL_CHOLESKY\n",
    "options.minimizer_progress_to_stdout = False\n",
    "options.num_threads = -1\n",
    "summary = pyceres.SolverSummary()\n",
    "pyceres.solve(options, prob, summary)\n",
    "print(summary.BriefReport())\n",
    "\n",
    "# Re-evaluate the error on the poses that were registered as parameter blocks.\n",
    "err_opt = np.array([error(gt, est) for gt, est in zip(i_from_w, i_from_w_opt)])\n",
    "print(\"optimized:\", np.mean(err_opt, 0))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b89ed97",
   "metadata": {},
   "source": [
    "## PGO with absolute poses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4adde3a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restart from a fresh copy of the perturbed initial poses.\n",
    "i_from_w_opt = deepcopy(i_from_w_init)\n",
    "\n",
    "err_opt = np.array([error(gt, est) for gt, est in zip(i_from_w, i_from_w_opt)])\n",
    "print(\"initial:\", np.mean(err_opt, 0))\n",
    "\n",
    "prob = pyceres.Problem()\n",
    "loss = pyceres.TrivialLoss()\n",
    "costs = []\n",
    "for target, estimate in zip(i_from_w, i_from_w_opt):\n",
    "    # One residual per pose, built from that pose's ground-truth value.\n",
    "    cost = pycolmap.cost_functions.AbsolutePoseErrorCost(target, np.eye(6))\n",
    "    costs.append(cost)\n",
    "    quat = estimate.rotation.quat\n",
    "    prob.add_residual_block(cost, loss, [quat, estimate.translation])\n",
    "    prob.set_manifold(quat, pyceres.EigenQuaternionManifold())\n",
    "\n",
    "options = pyceres.SolverOptions()\n",
    "options.linear_solver_type = pyceres.LinearSolverType.DENSE_QR\n",
    "options.minimizer_progress_to_stdout = False\n",
    "options.num_threads = -1\n",
    "summary = pyceres.SolverSummary()\n",
    "pyceres.solve(options, prob, summary)\n",
    "print(summary.BriefReport())\n",
    "\n",
    "# Re-evaluate the error on the poses that were registered as parameter blocks.\n",
    "err_opt = np.array([error(gt, est) for gt, est in zip(i_from_w, i_from_w_opt)])\n",
    "print(\"optimized:\", np.mean(err_opt, 0))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
